import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import cv2
import numpy as np
import requests
from urllib.request import urlopen
import imutils
from model import ArtNet
import torch.nn.functional as nnf
import time

from utils import *


# Build the network and restore the trained weights.
# NOTE(review): 11 is presumably the number of output classes -- confirm
# against the ArtNet definition in model.py.
model = ArtNet(11)
model_path = "/home/shivam-wiz/Downloads/MLPR___/Trial/best_checkpoint.model"
# map_location="cpu": this script never moves the model off the CPU, so remap
# the checkpoint's tensors to CPU. Without it, a checkpoint saved from a GPU
# fails to load on a machine that has no CUDA device.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()  # inference mode: disables dropout / batch-norm statistic updates

path = "/home/shivam-wiz/Downloads/MLPR___/video.mp4"
cap = cv2.VideoCapture(path)

# Check if the video file opened successfully. This runs before the output
# writer is created so a bad input path does not leave an output file behind.
if not cap.isOpened():
    print("Error: Could not open video file.")
    # Exit with a non-zero status; the bare `exit()` helper comes from the
    # `site` module (not guaranteed to exist) and would report success.
    raise SystemExit(1)

# Text overlay settings.
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 1
font_color = (255, 255, 255)  # White color in BGR
font_thickness = 2
text_position1 = (50, 50)  # Adjust the position as needed
text_position2 = (50, 100)  # Adjust the position as needed

# Output video: 20 FPS, 1200x800. Frames passed to `out.write` must have
# exactly this size.
fourcc = cv2.VideoWriter_fourcc(*'MP4V')
out = cv2.VideoWriter('output_video.mp4', fourcc, 20.0, (1200, 800))

# Filled background rectangle drawn behind the overlay text.
rectangle_color = (0, 0, 0)  # Black color in BGR
rectangle_position = (40, 10)  # Rectangle position (x, y)
rectangle_size = (700, 100)  # Rectangle size (width, height)

def preprocess_image(image_path):
    """Load an image from disk and turn it into a single-item model batch.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The result of ``transformer(image)`` with a leading batch dimension
        added via ``unsqueeze(0)``.
    """
    # NOTE(review): `transformer` is not defined in this file; presumably it
    # is provided by `from utils import *` -- confirm.
    # `Image.open` is lazy and keeps the file handle open; the context manager
    # closes it deterministically once the transform has consumed the pixels
    # (this function is called once per video frame).
    with Image.open(image_path) as image:
        input_tensor = transformer(image)
    return input_tensor.unsqueeze(0)  # Add a batch dimension

# Loop through each frame of the video: classify it, draw the prediction on
# the frame, show it and append it to the output video. The try/finally makes
# sure the capture, the writer and the preview window are released even if an
# exception interrupts the loop.
# NOTE(review): `classes` and `class_description` are not defined in this
# file; presumably they come from `from utils import *` -- confirm.
try:
    while True:
        # Capture frame-by-frame
        ret, frame = cap.read()

        # Check the read result BEFORE touching the frame: at end of stream
        # `frame` is None and cv2.resize would raise instead of ending cleanly.
        if not ret:
            print("End of video.")
            break

        # Must match the frame size the VideoWriter was created with.
        frame = cv2.resize(frame, (1200, 800))

        # The frame is handed to the model through a JPEG file on disk.
        # NOTE(review): this round trip is slow and lossy; converting the
        # array in memory would avoid it but would slightly change the
        # model's inputs, so it is kept as-is here.
        if not cv2.imwrite("input_image.jpg", frame):
            # Otherwise a stale file from a previous frame/run would be
            # classified silently.
            raise OSError("Could not write input_image.jpg")

        with torch.no_grad():
            input_image = preprocess_image("input_image.jpg")
            output = model(input_image)
            # output[0]: scores of the single item in the batch.
            probabilities = torch.nn.functional.softmax(output[0], dim=0)
            predicted_class = torch.argmax(probabilities).item()

        # Map the class index to its label, then to its description.
        predicted_label = class_description[classes[predicted_class]]
        confidence = probabilities[predicted_class].item()

        text1 = f"Predicted Class: {predicted_label}"
        text2 = f"Confidence: {confidence:.2%}"

        # Filled (thickness -1) rectangle as background for the text.
        rectangle_end = (
            rectangle_position[0] + rectangle_size[0],
            rectangle_position[1] + rectangle_size[1],
        )
        cv2.rectangle(frame, rectangle_position, rectangle_end, rectangle_color, -1)
        cv2.putText(frame, text1, text_position1, font, font_scale, font_color, font_thickness)
        cv2.putText(frame, text2, text_position2, font, font_scale, font_color, font_thickness)
        cv2.imshow("Video Frame", frame)
        out.write(frame)

        # Break the loop if 'q' is pressed
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Release the video file, finalize the output video (without
    # out.release() the written file may be left incomplete) and close all
    # OpenCV windows.
    cap.release()
    out.release()
    cv2.destroyAllWindows()
